from pyspark.ml.recommendation import ALS
from pyspark.ml.recommendation import ALSModel
from pyspark import SparkContext
from pyspark.sql import SparkSession
from pyspark.sql.functions import lit
import json


# Spark ALS music-recommendation reporting script; entry point is the
# __main__ guard at the bottom of the file.


def GenreList(sc, spark, trainData, n=10):
    """Return the IDs of the `n` artists with the most play records.

    Args:
        sc, spark: unused; kept for interface compatibility with the other
            report helpers in this script.
        trainData: DataFrame with an 'artist' column (one row per play record).
        n: how many top artists to return (default 10, matching original).

    Returns:
        list of artist IDs, ordered by descending record count.
    """
    top_artists = (trainData.groupBy('artist')
                   .count()
                   .orderBy('count', ascending=False)
                   .rdd.map(lambda row: row.artist)
                   .take(n))
    return top_artists

# NOTE(review): this statement runs at import time, but sc/spark/trainData
# are only created inside the __main__ guard below, so the original line
# always raised NameError before main ever ran.  Fall back to an empty
# list so the module can at least be imported.  The name also shadows the
# builtin `list`; it is kept because the __main__ block passes it onward.
try:
    list = GenreList(sc, spark, trainData)
except NameError:
    list = []


# Report on the ten artists with the most play records overall.
def GenYearSales(sc, spark, artistByID, list):
    """Write the top-10 artists by play-record count to genre-year-sales.json.

    Reads the module-level `trainData` DataFrame directly; the `sc`, `spark`
    and `list` parameters are unused and kept only for interface
    compatibility.  Each output entry is the grouped Row as a dict, with
    the numeric artist ID replaced by the artist name from `artistByID`.
    """
    # NOTE(review): this ranks by number of play records (row count), while
    # `artist_list` below ranks by summed play counts — confirm which ranking
    # was intended here.
    top_rows = trainData.groupBy('artist').count().orderBy('count',
                                                           ascending=False).take(10)
    for idx, row in enumerate(top_rows):
        print(row.artist)
        # Look up the human-readable name for this artist ID.
        name = artistByID.filter(artistByID['artist'] == row.artist).select('name').collect()[0]
        record = row.asDict()
        print(name)
        record.update({'artist': name.name})
        print(record)
        top_rows[idx] = record
    # Context manager guarantees the file handle is closed even on error.
    with open('/usr/local/spark/test/code/static/data/genre-year-sales.json', 'w') as f:
        f.write(json.dumps(top_rows))


# Top-10 artists by total summed play count, intended as input to genreSales.
# NOTE(review): as with `list` above, this executed at import time before
# trainData exists and always raised NameError; degrade to an empty list so
# the module can be imported (the __main__ block should compute the real
# value once trainData is loaded).
try:
    artist_list = (trainData.groupBy('artist').sum("count")
                   .orderBy('sum(count)', ascending=False).take(10))
except NameError:
    artist_list = []
        

def genreSales(sc, spark, artistByID, artist_list):
    """Write `artist_list` (top artists by summed play count) to genre-sales.json.

    Args:
        sc, spark: unused; kept for interface compatibility.
        artistByID: DataFrame mapping artist ID ('artist') to name ('name').
        artist_list: list of Rows with at least an 'artist' field.

    Each entry is converted to a dict with the numeric artist ID replaced by
    the artist's name before being serialized.
    """
    for idx, row in enumerate(artist_list):
        # Resolve the artist ID to its display name.
        name = artistByID.filter(artistByID['artist'] == row.artist).select('name').collect()[0]
        record = row.asDict()
        record.update({'artist': name.name})
        print(record)
        artist_list[idx] = record
    # Context manager guarantees the file handle is closed even on error.
    with open('/usr/local/spark/test/code/static/data/genre-sales.json', 'w') as f:
        f.write(json.dumps(artist_list))





def makeRecommendations(model, userID, number):
    """Score every known artist for `userID` and return the top `number`.

    Bug fix: the original body ignored both parameters — it used the global
    `modelnew` instead of `model`, and hard-coded take(10) instead of
    `number`.  Behavior is unchanged for the existing caller, which passes
    (modelnew, userID, 10).

    Returns:
        list of Rows (artist, prediction), best predictions first.
    """
    # Build a (artist, user) frame pairing every item factor with this user.
    toRecommend = model.itemFactors.selectExpr("id as artist").withColumn("user", lit(userID))
    # ALS expects integer id columns.
    toRecommend2 = (toRecommend
                    .withColumn("artist", toRecommend['artist'].cast("Int"))
                    .withColumn("user", toRecommend['user'].cast('Int')))
    toRecommend2.printSchema()
    top = (model.transform(toRecommend2)
           .select("artist", "prediction")
           .orderBy('prediction', ascending=False)
           .take(number))
    return top

def artistPredict(userID):
    """Write the top-10 recommended artists for `userID` to predict.json.

    Relies on the module-level globals `modelnew` (ALS model) and
    `artistByID` (ID -> name DataFrame) created under __main__.
    """
    recommendations = makeRecommendations(modelnew, userID, 10)
    for idx, row in enumerate(recommendations):
        # itemFactors ids are ints, but artistByID ids are CSV strings,
        # hence the str() cast before comparing.
        name = artistByID.filter(artistByID['artist'] == str(row.artist)).select('name').collect()[0]
        record = row.asDict()
        record.update({'name': name.name})
        recommendations[idx] = record
    # Context manager guarantees the file handle is closed even on error.
    with open('/usr/local/spark/test/code/static/data/predict.json', 'w') as f:
        f.write(json.dumps(recommendations))









if __name__ == "__main__":
    sc = SparkContext('local', 'test')
    sc.setLogLevel("WARN")

    spark = SparkSession.builder.getOrCreate()
    # Pre-trained ALS model and the supporting lookup/training tables.
    modelnew = ALSModel.load("/usr/local/spark/Model/modelnew")
    artistByID = spark.read.csv("/usr/local/spark/Model/artistByID").toDF("artist", "name")
    trainData = spark.read.csv("/usr/local/spark/Model/trainData").toDF("user", "artist", "count")

    trainData.cache()
    artistByID.cache()
    # CSV columns load as strings; the play count must be numeric.
    trainData = trainData.withColumn('count', trainData['count'].cast('int'))
    trainData.printSchema()

    # Compute the report inputs here: the module-level assignments of
    # `artist_list` and `list` run at import time, before sc/spark/trainData
    # exist, and cannot produce usable values.
    artist_list = (trainData.groupBy('artist').sum("count")
                   .orderBy('sum(count)', ascending=False).take(10))
    genreSales(sc, spark, artistByID, artist_list)

    genre_list = GenreList(sc, spark, trainData)
    GenYearSales(sc, spark, artistByID, genre_list)

    artistPredict(1000112)




















